import zmq
import time
import random
import Queue
import threading
import msgpack
import json
import snappy

def _unpack_msgpack_snappy(str) :
    if str[0] == 'S':
        tmp = snappy.uncompress(str[1:])
        obj = msgpack.loads(tmp)
    elif str[0] == '\0':
        obj = msgpack.loads(str[1:])
    else:
        return None

    return obj

def _pack_msgpack_snappy(obj) :
    """Serialize *obj* with msgpack, snappy-compressing payloads over 1000 bytes.

    The result is prefixed with a one-character marker understood by
    ``_unpack_msgpack_snappy``: ``'S'`` for a compressed body, a NUL
    character for an uncompressed one.
    """
    packed = msgpack.dumps(obj)
    # Small payloads are sent raw; compression rarely pays off for them.
    if len(packed) <= 1000:
        return '\0' + packed
    return 'S' + snappy.compress(packed)

def _unpack_msgpack(str) :
    """Decode a plain (unmarked, uncompressed) msgpack payload."""
    decoded = msgpack.loads(str)
    return decoded

def _pack_msgpack(obj) :
    """Encode *obj* as a plain (unmarked, uncompressed) msgpack payload."""
    encoded = msgpack.dumps(obj)
    return encoded

def _unpack_json(str) :
    return json.loads(str)

def _pack_json(obj) :
    return json.dumps(obj)

class JRpcClient :
    """Request/response RPC client over a ZeroMQ DEALER socket.

    Threading model (as set up in ``__init__``):

    * ``_recv_run`` is the only thread that touches the remote socket. Other
      threads hand it commands ("CONNECT" or "SEND:<payload>") through an
      inproc PUSH/PULL socket pair.
    * ``_callback_run`` executes functions queued via ``_async_call`` so user
      hooks never block the socket thread.

    Optional user hooks (assign after construction; each runs on the
    callback thread):

    * ``on_connected()`` -- a heartbeat reply arrived while disconnected.
    * ``on_disconnected()`` -- no heartbeat reply for longer than the
      heartbeat timeout while connected.
    * ``on_rpc_callback(method, result)`` -- a server notification, or the
      result carried by a heartbeat reply.
    """

    # Class-level defaults so the hook attributes always exist and can be
    # tested by _on_data_arrived / _recv_run before the user assigns them.
    on_connected = None
    on_disconnected = None
    on_rpc_callback = None

    def __init__(self, data_format="msgpack_snappy") :
        """Create the client and start its socket and callback threads.

        :param data_format: wire codec, one of "msgpack_snappy", "msgpack"
            or "json".
        :raises AssertionError: for an unknown *data_format* (the same
            exception type the original ``assert`` produced).
        """
        # Resolve the codec first, so an invalid format fails before any
        # sockets or threads are created.
        codecs = {
            "msgpack_snappy": (_pack_msgpack_snappy, _unpack_msgpack_snappy),
            "msgpack":        (_pack_msgpack, _unpack_msgpack),
            "json":           (_pack_json, _unpack_json),
        }
        if data_format not in codecs:
            # Raised explicitly (not via ``assert``) so it survives ``python -O``.
            raise AssertionError("unknown data_format " + data_format)
        self._pack, self._unpack = codecs[data_format]

        # callid -> Queue of the caller blocked in call(); guarded by _waiter_lock.
        self._waiter_lock = threading.Lock()
        self._waiter_map = {}

        self._should_close = False
        self._next_callid = 0
        self._send_lock = threading.Lock()
        self._callid_lock = threading.Lock()

        self._last_heartbeat_rsp_time = 0
        self._connected = False

        self.on_connected = None
        self.on_disconnected = None
        self.on_rpc_callback = None
        self._callback_queue = Queue.Queue()
        # A single cached wait queue, reused across call()s (see _alloc_wait_queue).
        self._call_wait_queue = Queue.Queue()

        # Commands for the socket thread travel over this inproc PUSH/PULL pair.
        self._ctx = zmq.Context()
        self._pull_sock = self._ctx.socket(zmq.PULL)
        self._pull_sock.bind("inproc://pull_sock")
        self._push_sock = self._ctx.socket(zmq.PUSH)
        self._push_sock.connect("inproc://pull_sock")

        # Seconds between heartbeats / seconds without a reply before disconnect.
        self._heartbeat_interval = 1
        self._heartbeat_timeout = 3

        self._addr = None

        for target in (self._recv_run, self._callback_run):
            t = threading.Thread(target=target)
            t.daemon = True
            t.start()

    def __del__(self):
        self.close()

    def next_callid(self):
        """Return the next call id for this client (thread-safe, starts at 1)."""
        with self._callid_lock:
            self._next_callid += 1
            return self._next_callid

    def set_heartbeat_options(self, interval, timeout):
        """Set the heartbeat period and the silence (both in seconds) that counts as a disconnect."""
        self._heartbeat_interval = interval
        self._heartbeat_timeout = timeout

    def _recv_run(self):
        """Socket thread: owns the remote DEALER socket for its whole life.

        Each iteration it (1) flags a disconnect when heartbeat replies stop,
        (2) sends a heartbeat every ``_heartbeat_interval`` seconds while a
        remote socket exists, (3) executes "CONNECT" / "SEND:<payload>"
        commands from the inproc pull socket and (4) dispatches data received
        from the remote peer. The loop ends once close() has been called.
        """
        heartbeat_time = 0

        poller = zmq.Poller()
        poller.register(self._pull_sock, zmq.POLLIN)

        remote_sock = None

        try:
            while not self._should_close:
                try:
                    if self._connected and time.time() - self._last_heartbeat_rsp_time > self._heartbeat_timeout:
                        self._connected = False
                        if self.on_disconnected:
                            self._async_call(self.on_disconnected)

                    if remote_sock and time.time() - heartbeat_time > self._heartbeat_interval:
                        self._send_hearbeat()
                        heartbeat_time = time.time()

                    # Short poll timeout so heartbeats and close() are noticed promptly.
                    socks = dict(poller.poll(500))
                    if socks.get(self._pull_sock) == zmq.POLLIN:
                        cmd = self._pull_sock.recv()
                        if cmd == "CONNECT":
                            # Replace any existing connection with a fresh socket.
                            if remote_sock:
                                poller.unregister(remote_sock)
                                remote_sock.close()
                                remote_sock = None

                            remote_sock = self._do_connect()

                            if remote_sock:
                                poller.register(remote_sock, zmq.POLLIN)

                        elif cmd.startswith("SEND:") and remote_sock:
                            remote_sock.send(cmd[5:])

                    if remote_sock and socks.get(remote_sock) == zmq.POLLIN:
                        data = remote_sock.recv()
                        if data:
                            self._on_data_arrived(str(data))

                except zmq.error.Again:
                    # RCVTIMEO/SNDTIMEO expired on the remote socket; retry next round.
                    pass
                except Exception as e:
                    print("_recv_run: %s" % e)
        finally:
            # Release the remote socket once the loop ends (LINGER is 0,
            # so this does not block on unsent messages).
            if remote_sock is not None:
                remote_sock.close()

    def _callback_run(self):
        """Callback thread: run functions queued by _async_call until close()."""
        while not self._should_close:
            r = None  # bound up front so the error report below never hits NameError
            try:
                r = self._callback_queue.get(timeout = 1)
                if r :
                    r()
            except Queue.Empty:
                pass
            except Exception as e:
                print("_callback_run {} {} {}".format(r, type(e), e))

    def _async_call(self, func):
        """Schedule *func* (no arguments) to run on the callback thread."""
        self._callback_queue.put( func )

    def _send_request(self, payload) :
        """Ask the socket thread to send an already-packed *payload* to the peer."""
        with self._send_lock:
            self._push_sock.send("SEND:" + payload)

    def connect(self, addr) :
        """Connect (or reconnect) to *addr*; the socket thread opens the socket."""
        self._addr = addr
        # The PUSH socket is shared by every caller thread and by the heartbeat
        # sender; ZeroMQ sockets must not be used concurrently, so serialize
        # on the same lock _send_request uses.
        with self._send_lock:
            self._push_sock.send("CONNECT")

    def _do_connect(self):
        """Create a DEALER socket with a random identity and connect it to self._addr."""
        client_id = str(random.randint(1000000, 100000000))

        sock = self._ctx.socket(zmq.DEALER)
        sock.identity = client_id + '$' + str(random.randint(1000000, 1000000000))
        sock.setsockopt(zmq.RCVTIMEO, 500)
        sock.setsockopt(zmq.SNDTIMEO, 500)
        sock.setsockopt(zmq.LINGER, 0)
        sock.connect(self._addr)

        return sock

    def close(self):
        """Ask both background threads to stop; they exit within about a second."""
        self._should_close = True

    def _on_data_arrived(self, data):
        """Decode one message from the peer and route it.

        Heartbeat replies refresh the connection state. Messages with a truthy
        ``id`` are handed to the matching waiter in call(). Anything else with
        ``method`` and ``result`` is reported as a notification via
        on_rpc_callback.
        """
        try:
            msg = self._unpack(data)

            if not msg:
                print("wrong message format")
                return

            if msg.get('method') == '.sys.heartbeat':
                self._last_heartbeat_rsp_time = time.time()
                if not self._connected:
                    self._connected = True
                    if self.on_connected :
                        self._async_call(self.on_connected)

                # Let the user inspect whatever the server put in the heartbeat result.
                if 'result' in msg and self.on_rpc_callback :
                    self._async_call( lambda: self.on_rpc_callback(msg['method'], msg['result']) )

            elif msg.get('id'):
                # Reply to a call(): deliver it only if the caller is still waiting.
                callid = int(msg['id'])
                with self._waiter_lock:
                    q = self._waiter_map.get(callid)
                    if q is not None:
                        q.put(msg)
            else:
                # Notification message
                if 'method' in msg and 'result' in msg and self.on_rpc_callback :
                    self._async_call( lambda: self.on_rpc_callback(msg['method'], msg['result']) )

        except Exception as e:
            print("_on_data_arrived: %s" % e)

    def _send_hearbeat(self):
        """Send a '.sys.heartbeat' request carrying the local time."""
        msg = { 'jsonrpc' : '2.0',
                'method'  : '.sys.heartbeat',
                'params'  : { 'time': time.time() },
                'id'      : str(self.next_callid()) }
        self._send_request(self._pack(msg))

    def _alloc_wait_queue(self):
        """Take the cached wait queue if available, otherwise create a new one."""
        with self._waiter_lock:
            q, self._call_wait_queue = self._call_wait_queue, None
        return q if q is not None else Queue.Queue()

    def _free_wait_queue(self, q):
        """Return *q* to the one-slot cache (it is dropped if the slot is taken).

        Must only be called after q's call id has been removed from
        _waiter_map, so nothing can be put into it concurrently.
        """
        # Discard leftovers (e.g. a reply that raced with a timeout) so the
        # next caller cannot receive a stale response.
        while True:
            try:
                q.get_nowait()
            except Queue.Empty:
                break
        with self._waiter_lock:
            if self._call_wait_queue is None:
                self._call_wait_queue = q

    def call(self, method, params, timeout = 6) :
        """Send an RPC request and, unless *timeout* is falsy, wait for the reply.

        :param method: remote method name.
        :param params: request parameters, packed as-is with the configured codec.
        :param timeout: seconds to wait for the reply; a falsy value sends the
            request without waiting.
        :returns: when waiting, a dict with the reply's ``result`` and/or
            ``error`` entries, or
            ``{'error': {'error': -1, 'message': 'timeout'}}`` when no reply
            with either key arrived in time; when not waiting,
            ``{'result': True}``.
        """
        callid = self.next_callid()
        msg = { 'jsonrpc' : '2.0',
                'method'  : method,
                'params'  : params,
                'id'      : str(callid) }

        if not timeout:
            # Fire-and-forget: no waiter is registered, so any reply is dropped.
            self._send_request(self._pack(msg))
            return { 'result': True }

        # Register the waiter before sending so a fast reply cannot be missed.
        q = self._alloc_wait_queue()
        with self._waiter_lock:
            self._waiter_map[callid] = q

        try:
            self._send_request(self._pack(msg))
            try:
                r = q.get(timeout = timeout)
            except Queue.Empty:
                r = None
        finally:
            # Unregister under the lock first so no late reply can be routed
            # to q, then recycle q. pop() also keeps _waiter_map from growing
            # by one entry per call.
            with self._waiter_lock:
                self._waiter_map.pop(callid, None)
            self._free_wait_queue(q)

        ret = {}
        if r:
            if 'result' in r:
                ret['result'] = r['result']
            if 'error' in r:
                ret['error'] = r['error']

        return ret if ret else { 'error': {'error': -1, 'message': "timeout"}}
